"""
Phase 10: First-Differenced Local Projections
===============================================
The honest pre-trend fix.

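Level variables (Investment/GDP, MPK proxy) are differenced against the
pre-shock baseline y_{t-1} before constructing horizons. This removes
persistence, so the pre-trend tests at h=-2 and h=-3 are not mechanically
driven by the level of y. At h=-1 the level outcome is y_{t-1} - y_{t-1},
which is zero by construction: it is the normalisation period and carries no
pre-trend information.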

Growth variables (Δlog K/L, ΔTFP, GDP growth) are already in differences,
so they are not differenced again; instead they are summed over each
horizon window.

LP specification:
  Δy_{t+h} = α + β_h · x_t + controls_t + ε_{t+h}

where, for level vars, Δy_{t+h} = y_{t+h} - y_{t-1} (change relative to the
pre-shock baseline, for h < 0 as well as h >= 0); for growth vars it is the
cumulative sum of growth over t..t+h when h >= 0, or over the |h| periods
t+h..t-1 when h < 0.

Pre-trend test: Δy at h<0 should be null — no pre-existing change in
outcomes before the demographic inflow "shock."
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

# Directory layout, resolved from this file's location: PROJECT_DIR is the
# grandparent directory of this file and ROOT_DIR is PROJECT_DIR's parent.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
# Put the sibling "multilateral/src" directory first on the import path so the
# next line can import PanelGLS from its `model` module; order matters here.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input panel directory and output directory for the CSV/markdown tables.
DATA = PROJECT_DIR / "data" / "processed"
OUT_TABLES = PROJECT_DIR / "output" / "tables"
# Side effect at import time: create the output directory if it is missing.
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# Post-shock horizons are 0..MAX_HORIZON inclusive; PRE_HORIZONS are the
# pre-shock horizons used for the pre-trend check.
MAX_HORIZON = 5
PRE_HORIZONS = [-3, -2, -1]
# Control regressors included in every LP alongside the instrument
# (run_lp silently drops any that are absent from the panel).
CONTROLS = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'log_rel_opw', 'kaopen']

# ISO3 codes defining the OECD subsample (38 countries); main() restricts the
# panel to these.
OECD = {
    'AUS', 'AUT', 'BEL', 'CAN', 'CHE', 'CHL', 'COL', 'CRI', 'CZE', 'DEU',
    'DNK', 'ESP', 'EST', 'FIN', 'FRA', 'GBR', 'GRC', 'HUN', 'IRL', 'ISL',
    'ISR', 'ITA', 'JPN', 'KOR', 'LTU', 'LUX', 'LVA', 'MEX', 'NLD', 'NOR',
    'NZL', 'POL', 'PRT', 'SVK', 'SVN', 'SWE', 'TUR', 'USA',
}

# Outcome column -> display label. Level variables are differenced against
# their pre-shock value y_{t-1}; growth variables are already differences and
# are summed over each horizon window instead.
LEVEL_OUTCOMES = {
    'gross_fixed_investment_gdp': 'Δ Investment/GDP',
    'mpk_proxy': 'Δ MPK',
}
GROWTH_OUTCOMES = {
    'delta_log_kl': 'Cumul. Δlog K/L',
    'rgdp_growth': 'Cumul. GDP Growth',
    'delta_log_tfp': 'Cumul. Δlog TFP',
}


def stars(p):
    """Return the significance marker for p-value ``p``.

    '***' if p < 0.01, '**' if p < 0.05, '*' if p < 0.1, otherwise ''.
    A NaN p-value fails every comparison and therefore yields ''.
    """
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def build_level_horizon(df, y_var, h):
    """Build the horizon-h outcome for a level variable: y_{t+h} - y_{t-1}.

    The baseline is always the pre-shock value y_{t-1}. For h >= 0 the outcome
    is the cumulative change after the shock; for h < 0 it is the (signed)
    difference between an earlier period and that baseline.

    At h = -1 the outcome would be y_{t-1} - y_{t-1}, which is identically
    zero: that period is the normalisation, not a testable pre-trend. The
    column is therefore filled with NaN so that no regression is run on an
    all-zero outcome (downstream, ``run_lp`` then finds no complete rows).

    Leads/lags are taken by row position within each ``iso3`` after sorting by
    year, so a gap in a country's years is shifted across rather than filled.
    NOTE(review): this assumes consecutive years per country; confirm the
    panel is balanced.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel with ``iso3``, ``year`` and ``y_var`` columns. Not modified.
    y_var : str
        Level variable to difference.
    h : int
        Horizon relative to period t; negative values refer to periods before t.

    Returns
    -------
    (pandas.DataFrame, str)
        A sorted copy of ``df`` with the new column, and that column's name
        ``d_{y_var}_h{h}``.
    """
    df = df.sort_values(['iso3', 'year']).copy()
    col = f'd_{y_var}_h{h}'

    if h == -1:
        # Reference period: the difference is zero by construction.
        df[col] = np.nan
        return df, col

    by_country = df.groupby('iso3')[y_var]
    y_baseline = by_country.shift(1)   # y_{t-1}
    y_target = by_country.shift(-h)    # y_{t+h}: a lead for h > 0, a lag for h < 0
    df[col] = y_target - y_baseline
    return df, col


def build_growth_horizon(df, y_var, h):
    """Cumulate an already-differenced (growth) variable over a horizon window.

    * h == 0: the period-t value itself.
    * h > 0:  g_t + g_{t+1} + ... + g_{t+h}  (h+1 terms; NaN unless all exist).
    * h < 0:  g_{t-|h|} + ... + g_{t-1}      (|h| terms strictly before t).

    Windows are computed per ``iso3`` after sorting by year, by row position.

    Returns
    -------
    (pandas.DataFrame, str)
        A sorted copy of ``df`` with the new column, and its name ``{y_var}_h{h}``.
    """
    out = df.sort_values(['iso3', 'year']).copy()
    name = f'{y_var}_h{h}'

    if h == 0:
        out[name] = out[y_var]
        return out, name

    # Forward window: sum ending at t+h, moved back to t.
    # Backward window: sum ending at t-1, moved forward to t.
    if h > 0:
        span, lag = h + 1, -h
    else:
        span, lag = -h, 1

    def window_sum(series):
        return series.rolling(window=span, min_periods=span).sum().shift(lag)

    out[name] = out.groupby('iso3')[y_var].transform(window_sum)
    return out, name


def run_lp(df, y_col, x_vars, controls, key_var):
    """Estimate one local-projection regression and report the ``key_var`` coefficient.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel containing ``y_col``, the regressors, ``iso3`` and ``year``.
    y_col : str
        Dependent-variable column (the horizon-h outcome).
    x_vars, controls : list of str
        Regressors, in order. Columns absent from ``df`` are silently dropped,
        and a name listed more than once is used only once.
    key_var : str
        Regressor whose coefficient is reported.

    Returns
    -------
    dict or None
        ``coef``, ``se``, ``p``, 95% normal-approximation bounds ``ci_lo`` /
        ``ci_hi`` (coef ± 1.96·se), ``n_obs`` and ``r_squared`` from
        ``PanelGLS``. Returns ``None`` when ``key_var`` is not an available
        regressor, when fewer than 30 complete rows remain after dropping
        missing values, or when the fit raises.
    """
    # dict.fromkeys de-duplicates while keeping the caller's column order.
    wanted = list(dict.fromkeys([y_col, *x_vars, *controls, 'iso3', 'year']))
    sub = df[[c for c in wanted if c in df.columns]].dropna()
    regressors = [v for v in dict.fromkeys([*x_vars, *controls]) if v in sub.columns]

    # Without the key regressor there is nothing to report, so skip the fit.
    if key_var not in regressors or len(sub) < 30:
        return None

    gls = PanelGLS()
    try:
        gls.fit(sub[y_col].values, sub[regressors].values,
                sub['iso3'].values, sub['year'].values)
    except Exception:
        # Best-effort: any exception from PanelGLS is reported as a missing
        # horizon rather than aborting the whole IRF.
        return None

    idx = regressors.index(key_var)
    beta, se = gls.beta[idx], gls.se[idx]
    return {
        'coef': beta,
        'se': se,
        'p': gls.pvalues[idx],
        'ci_lo': beta - 1.96 * se,
        'ci_hi': beta + 1.96 * se,
        'n_obs': gls.n_obs,
        'r_squared': gls.r_squared,
    }


def make_ascii_irf(irf_data, title, width=60):
    """Render an impulse response as a fixed-width ASCII table with a CI bar plot.

    Parameters
    ----------
    irf_data : dict
        Horizon (int) -> result dict as returned by ``run_lp`` (keys ``coef``,
        ``se``, ``p``, ``ci_lo``, ``ci_hi``, ``n_obs``), or ``None``.
    title : str
        Heading printed (and underlined) above the table.
    width : int, default 60
        Number of character columns in the plot area.

    Returns
    -------
    str
        The rendered block. Horizons with no usable estimate (missing result,
        NaN coefficient, or non-finite confidence bound) print as ``n/a``; if
        no horizon is usable, a short "(no data)" stub is returned instead.
    """
    def usable(pt):
        # A NaN SE gives NaN CI bounds, which would crash int() in pos() below.
        return bool(pt) and all(np.isfinite(pt.get(k, np.nan))
                                for k in ('coef', 'ci_lo', 'ci_hi'))

    points = sorted(irf_data.items())
    all_vals = [v for _, pt in points if usable(pt) for v in (pt['ci_lo'], pt['ci_hi'])]
    if not all_vals:
        return title + '\n  (no data)\n'

    # Widen the axis to include zero so the zero marker is always on-scale.
    vmin = min(min(all_vals), 0)
    vmax = max(max(all_vals), 0)
    span = vmax - vmin if vmax > vmin else 1

    def pos(val):
        return int((val - vmin) / span * (width - 1))

    zero_pos = pos(0)
    # Blank prefix as wide as the h/coef/SE/p/N columns, so the closing rule
    # and the axis labels line up with the plot area.
    pad = f'     {"":>12s} {"":>8s} {"":>6s}  {"":>5s}  '

    lines = [title, '=' * len(title), '']
    lines.append(f'  h  {"coef":>12s} {"SE":>8s} {"p":>6s}  {"N":>5s}  Plot')
    lines.append(f'  -  {"----":>12s} {"--":>8s} {"---":>6s}  {"---":>5s}  ' + '-' * width)

    for h, pt in points:
        if not usable(pt):
            lines.append(f' {h:2d}  {"n/a":>12s}')
            continue

        coef_str = f"{pt['coef']:.5f}{stars(pt['p'])}"

        plot = [' '] * width
        lo_pos = max(0, pos(pt['ci_lo']))
        hi_pos = min(width - 1, pos(pt['ci_hi']))
        for i in range(lo_pos, hi_pos + 1):
            plot[i] = '-'
        # Draw the zero axis after the CI bar so it stays visible inside the
        # interval; the point marker, drawn last, still wins if they coincide.
        plot[zero_pos] = '|'
        c_pos = min(max(0, pos(pt['coef'])), width - 1)
        plot[c_pos] = '*' if pt['p'] < 0.05 else 'o' if pt['p'] < 0.1 else '·'

        lines.append(f' {h:2d}  {coef_str:>12s} {pt["se"]:8.5f} {pt["p"]:6.4f}  {pt["n_obs"]:5d}  {"".join(plot)}')

    lines.append(pad + '-' * width)
    # Axis labels at the left and right edges of the plot area.
    lines.append(pad + f'{vmin:<10.4f}{" " * (width - 20)}{vmax:>10.4f}')
    lines.append('  (* p<0.05, o p<0.1, · p>0.1, | = zero)')
    lines.append('')
    return '\n'.join(lines)


def run_full_irf(df, y_var, y_label, is_level, key_var='log_predicted_demo_inflows'):
    """Estimate the IRF of ``y_var`` on ``key_var`` across all horizons.

    Horizons are PRE_HORIZONS followed by 0..MAX_HORIZON. Each horizon's
    outcome is built with ``build_level_horizon`` when ``is_level`` is true,
    otherwise with ``build_growth_horizon``, and regressed on ``key_var`` plus
    CONTROLS via ``run_lp``. ``y_label`` is accepted but not used here.

    Returns
    -------
    dict
        Horizon -> ``run_lp`` result (or ``None`` when that horizon failed).
    """
    make_outcome = build_level_horizon if is_level else build_growth_horizon
    responses = {}
    for horizon in [*PRE_HORIZONS, *range(MAX_HORIZON + 1)]:
        frame, outcome_col = make_outcome(df, y_var, horizon)
        responses[horizon] = run_lp(frame, outcome_col, [key_var], CONTROLS, key_var)
    return responses


def _build_placebo_instrument():
    """Aggregate gravity-only predicted bilateral flows into a country-year placebo.

    Reads the model-'2c' coefficients from ``gravity_results.csv`` and uses only
    the five terms in ``gravity_only_vars``; any other model-2c terms are
    ignored. Bilateral predictions are exponentiated and summed over rows for
    each destination ``iso_d`` and year.

    Returns
    -------
    pandas.DataFrame
        Columns ``iso3``, ``year`` and ``log_predicted_placebo_inflows``.
    """
    grav = pd.read_csv(ROOT_DIR / "gravity_bilateral" / "output" / "tables" / "gravity_results.csv")
    model_2c = grav[grav['model'] == '2c: Gravity + Demographics + KAOPEN interactions']
    gravity_only_vars = ['log_dist', 'contiguity', 'common_lang_official',
                         'colonial_ties', 'log_gdp_product']
    gravity_coeffs = {
        row['variable']: row['coefficient']
        for _, row in model_2c.iterrows()
        if row['variable'] in gravity_only_vars
    }
    missing = [v for v in gravity_only_vars if v not in gravity_coeffs]
    if missing:
        # These terms enter the placebo with a zero coefficient; say so
        # instead of silently changing the instrument.
        print(f"  WARNING: no model-2c coefficient for {missing}; treated as 0 in placebo")

    bp = pd.read_csv(ROOT_DIR / "gravity_bilateral" / "data" / "processed" / "bilateral_panel.csv")
    present = [v for v in gravity_only_vars if v in bp.columns]
    valid = bp.dropna(subset=present).copy()
    valid['predicted_placebo'] = sum(gravity_coeffs.get(v, 0) * valid[v] for v in present)
    valid['predicted_placebo_level'] = np.exp(valid['predicted_placebo'])
    placebo = (valid.groupby(['iso_d', 'year'])
               .agg(predicted_placebo_inflows=('predicted_placebo_level', 'sum'))
               .reset_index().rename(columns={'iso_d': 'iso3'}))
    # Clip before the log so a zero sum cannot produce -inf.
    placebo['log_predicted_placebo_inflows'] = np.log(
        placebo['predicted_placebo_inflows'].clip(lower=1e-6))
    return placebo[['iso3', 'year', 'log_predicted_placebo_inflows']]


def _pretrend_status(irf):
    """Classify the pre-shock horizons of one IRF.

    Returns 'CLEAN' when no pre-horizon p-value is below 0.1, 'k/n sig'
    otherwise, and 'n/a' when no pre-horizon has a finite p-value, so a
    missing test is never reported as a clean one.
    """
    tested = [irf[h] for h in PRE_HORIZONS if irf.get(h) and np.isfinite(irf[h]['p'])]
    if not tested:
        return 'n/a'
    n_sig = sum(1 for pt in tested if pt['p'] < 0.1)
    return 'CLEAN' if n_sig == 0 else f'{n_sig}/{len(tested)} sig'


def _coef_summary(pt):
    """Format 'β=… (p=…) stars' for one horizon result, or 'n/a' if it is missing."""
    if not pt:
        return 'n/a'
    return f"β={pt['coef']:.4f} (p={pt['p']:.4f}) {stars(pt['p'])}"


def _table_cells(irf, horizons):
    """Markdown cells for one IRF row: coefficient with stars, or '' when missing/NaN."""
    cells = []
    for h in horizons:
        pt = irf.get(h)
        cells.append(f"{pt['coef']:.4f}{stars(pt['p'])}" if pt and not np.isnan(pt['coef']) else '')
    return cells


def main():
    """Run the first-differenced LP pipeline and write its tables.

    Estimates IRFs (including pre-shock horizons) of every outcome in
    LEVEL_OUTCOMES and GROWTH_OUTCOMES on the demographic instrument for the
    OECD subsample, repeats the Investment/GDP IRF with a gravity-only placebo
    instrument, prints a pre-trend summary, and saves
    ``firstdiff_lp_results.csv`` and ``firstdiff_lp.md`` under OUT_TABLES.
    """
    print("=" * 70)
    print("PHASE 10: FIRST-DIFFERENCED LOCAL PROJECTIONS")
    print("The honest pre-trend fix")
    print("=" * 70)

    df = pd.read_csv(DATA / "deepening_panel.csv")
    df_oecd = df[df['iso3'].isin(OECD)].copy()
    print(f"OECD panel: {len(df_oecd)} obs, {df_oecd['iso3'].nunique()} countries")

    horizons = PRE_HORIZONS + list(range(MAX_HORIZON + 1))
    all_irfs = {}
    all_plots = []

    # ── Part A: Level variables (first-differenced) ──
    print("\n" + "=" * 70)
    print("LEVEL VARIABLES — FIRST DIFFERENCED: y_{t+h} - y_{t-1}")
    print("=" * 70)

    for y_var, y_label in LEVEL_OUTCOMES.items():
        if y_var not in df_oecd.columns:
            continue

        print(f"\n--- {y_label} ---")
        irf = run_full_irf(df_oecd, y_var, y_label, is_level=True)
        all_irfs[y_var] = irf

        plot = make_ascii_irf(irf, f'{y_label} (first-diff): Demo Inflows (OECD)')
        print(plot)
        all_plots.append(plot)

    # ── Part B: Growth variables (already differenced) ──
    print("\n" + "=" * 70)
    print("GROWTH VARIABLES — CUMULATED (already in differences)")
    print("=" * 70)

    for y_var, y_label in GROWTH_OUTCOMES.items():
        if y_var not in df_oecd.columns:
            continue

        print(f"\n--- {y_label} ---")
        irf = run_full_irf(df_oecd, y_var, y_label, is_level=False)
        all_irfs[y_var] = irf

        plot = make_ascii_irf(irf, f'{y_label}: Demo Inflows (OECD)')
        print(plot)
        all_plots.append(plot)

    # ── Part C: Placebo comparison for first-differenced investment ──
    print("\n" + "=" * 70)
    print("PLACEBO CHECK: First-Differenced Investment/GDP")
    print("=" * 70)

    df_oecd_p = df_oecd.merge(_build_placebo_instrument(), on=['iso3', 'year'], how='left')

    if 'gross_fixed_investment_gdp' in df_oecd_p.columns:
        print("\n--- Demographic instrument ---")
        irf_demo = run_full_irf(df_oecd_p, 'gross_fixed_investment_gdp',
                                'Δ Invest/GDP', is_level=True,
                                key_var='log_predicted_demo_inflows')
        plot_d = make_ascii_irf(irf_demo, 'DEMOGRAPHIC → Δ Investment/GDP (OECD)')
        print(plot_d)
        all_plots.append(plot_d)

        print("\n--- Placebo instrument (gravity-only) ---")
        irf_plac = run_full_irf(df_oecd_p, 'gross_fixed_investment_gdp',
                                'Δ Invest/GDP', is_level=True,
                                key_var='log_predicted_placebo_inflows')
        plot_p = make_ascii_irf(irf_plac, 'PLACEBO (gravity-only) → Δ Investment/GDP (OECD)')
        print(plot_p)
        all_plots.append(plot_p)

        all_irfs['invest_placebo'] = irf_plac

    # ── Pre-trend assessment ──
    print("\n" + "=" * 70)
    print("PRE-TREND ASSESSMENT (First-Differenced)")
    print("=" * 70)

    for y_var in [*LEVEL_OUTCOMES, *GROWTH_OUTCOMES]:
        label = LEVEL_OUTCOMES.get(y_var, GROWTH_OUTCOMES.get(y_var, y_var))
        irf = all_irfs.get(y_var, {})
        if not irf:
            continue
        status = _pretrend_status(irf)
        print(f"  {label:30s}  pre-trend: {status:12s}  h=2: {_coef_summary(irf.get(2))}")

    # Check placebo
    plac_irf = all_irfs.get('invest_placebo', {})
    if plac_irf:
        print(f"  {'Placebo → Δ Invest/GDP':30s}  {'':12s}  h=2: {_coef_summary(plac_irf.get(2))}")

    # ── Save ──
    print("\n--- Saving results ---")

    rows = [{'outcome': y_var, 'horizon': h, **pt}
            for y_var, irf in all_irfs.items()
            for h, pt in irf.items() if pt]
    pd.DataFrame(rows).to_csv(OUT_TABLES / "firstdiff_lp_results.csv", index=False)

    # The report contains non-ASCII symbols (Δ, →, ·, β); write it as UTF-8
    # explicitly instead of relying on the platform's default encoding.
    with open(OUT_TABLES / "firstdiff_lp.md", 'w', encoding='utf-8') as f:
        f.write("# Table 10: First-Differenced Local Projections (OECD)\n\n")
        f.write("Level variables show y_{t+h} - y_{t-1} (cumulative change from pre-shock).\n")
        f.write("Growth variables show cumulative sum from t to t+h.\n")
        f.write("Pre-trend test: h<0 coefficients should be null.\n\n")

        f.write("## Summary\n\n")
        # Header derived from the horizon constants so it cannot drift from the cells.
        labels = [f'h={h}' for h in horizons]
        f.write("| Outcome | " + " | ".join(labels) + " |\n")
        f.write("|---------|" + "|".join('-' * (len(lab) + 2) for lab in labels) + "|\n")

        for y_var, y_label in {**LEVEL_OUTCOMES, **GROWTH_OUTCOMES}.items():
            cells = _table_cells(all_irfs.get(y_var, {}), horizons)
            f.write(f"| {y_label} | {' | '.join(cells)} |\n")

        # Placebo row
        if plac_irf:
            cells = _table_cells(plac_irf, horizons)
            f.write(f"| Placebo → Δ Invest | {' | '.join(cells)} |\n")

        f.write("\n")

        f.write("## IRF Plots\n\n```\n")
        for plot in all_plots:
            f.write(plot + '\n')
        f.write("```\n")

        f.write("\n*p<0.1, **p<0.05, ***p<0.01. OECD subsample. PanelGLS with AR(1).*\n")

    print(f"Saved: {OUT_TABLES / 'firstdiff_lp_results.csv'}")
    print(f"Saved: {OUT_TABLES / 'firstdiff_lp.md'}")
    print("\nPhase 10 complete.")


# Run the full pipeline only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
